import torch
import numpy as np
import pickle
from PIL import Image
import os
import torchvision
import random
cpath = os.path.dirname(__file__)
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from opacus import PrivacyEngine

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.transforms import Resize, InterpolationMode
from torch.utils.data import Dataset, DataLoader

class CIFAR10_CNN(nn.Module):
    """VGG-style CNN with Tanh activations and a 10-way linear output.

    Two input modes are supported, selected by ``in_channels``:

    * ``in_channels == 3``: the input is fed to the convolutional stack as
      is (32x32 images are assumed when sizing the classifier), followed by
      a two-layer classifier with a hidden width of 128.
    * any other value: the input is reshaped to ``(-1, in_channels, 8, 8)``,
      optionally normalised, and followed by a single linear classifier.

    Args:
        in_channels: Number of channels of the input.
        input_norm: Only used when ``in_channels != 3``. ``None`` for no
            normalisation, ``"GroupNorm"`` for a non-affine ``nn.GroupNorm``,
            any other value for ``standardize(x, bn_stats)``.
        **kwargs: Forwarded to :meth:`build` (``num_groups``, ``bn_stats``,
            ``size``).
    """

    # Spatial side length assumed for each input mode when sizing the
    # classifier; forward() reshapes non-3-channel inputs to _FEATURE_SIDE.
    _IMAGE_SIDE = 32
    _FEATURE_SIDE = 8

    def __init__(self, in_channels=3, input_norm=None, **kwargs):
        super(CIFAR10_CNN, self).__init__()
        self.in_channels = in_channels
        self.features = None
        self.classifier = None
        self.norm = None
        self.build(input_norm, **kwargs)

    def _input_side(self):
        """Return the spatial side length the network is sized for."""
        if self.in_channels == 3:
            return self._IMAGE_SIDE
        return self._FEATURE_SIDE

    def build(self, input_norm=None, num_groups=None,
              bn_stats=None, size=None):
        """Create ``self.norm``, ``self.features`` and ``self.classifier``.

        Args:
            input_norm: See the class docstring.
            num_groups: Group count passed to ``nn.GroupNorm`` when
                ``input_norm == "GroupNorm"``.
            bn_stats: Statistics handed to ``standardize`` for any other
                non-``None`` ``input_norm``.
            size: ``"small"`` selects the narrower layer configuration;
                anything else selects the default one.
        """
        # cfg lists conv output widths; 'M' stands for a 2x2 max-pool.
        if self.in_channels == 3:
            if size == "small":
                cfg = [16, 16, 'M', 32, 32, 'M', 64, 'M']
            else:
                cfg = [32, 32, 'M', 64, 64, 'M', 128, 128, 'M']
            self.norm = nn.Identity()
        else:
            if size == "small":
                cfg = [16, 16, 'M', 32, 32]
            else:
                cfg = [64, 'M', 64]
            if input_norm is None:
                self.norm = nn.Identity()
            elif input_norm == "GroupNorm":
                self.norm = nn.GroupNorm(num_groups, self.in_channels, affine=False)
            else:
                # NOTE(review): `standardize` is neither defined nor imported
                # in the visible file; unless it is provided elsewhere this
                # branch raises NameError on the first forward pass.
                self.norm = lambda x: standardize(x, bn_stats)
        layers = []
        act = nn.Tanh

        c = self.in_channels
        for v in cfg:
            if v == 'M':
                layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
            else:
                conv2d = nn.Conv2d(c, v, kernel_size=3, stride=1, padding=1)
                layers += [conv2d, act()]
                c = v

        self.features = nn.Sequential(*layers)

        # The 3x3/pad-1 convolutions keep the spatial size and every 'M'
        # halves it, so the flattened width follows directly from cfg.  This
        # yields 2048 for the default 3-channel net (as before) and the
        # correct 1024 for size="small", which previously mismatched.
        side = self._input_side() // (2 ** cfg.count('M'))
        flat_features = c * side * side
        if self.in_channels == 3:
            hidden = 128
            self.classifier = nn.Sequential(
                nn.Linear(flat_features, hidden), act(), nn.Linear(hidden, 10))
        else:
            self.classifier = nn.Linear(flat_features, 10)

    def get_flattened_size(self):
        """Return the per-sample feature count produced by ``self.features``.

        Measured by pushing one all-zero probe through the convolutional
        stack, without tracking gradients and without consuming random
        numbers.
        """
        side = self._input_side()
        # Put the probe on the same device as the convolution weights.
        device = next(self.features.parameters()).device
        probe = torch.zeros(1, self.in_channels, side, side, device=device)
        with torch.no_grad():
            output = self.features(probe)
        return output.view(output.size(0), -1).size(1)

    def forward(self, x):
        """Return class scores of shape ``(batch, 10)`` for input ``x``."""
        if self.in_channels != 3:
            x = self.norm(x.view(-1, self.in_channels,
                                 self._FEATURE_SIDE, self._FEATURE_SIDE))
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x

## random seeds
# Seed PyTorch's global generator before any model or loader is created.
torch.manual_seed(10)

## dataset with padding
# num_padding is the border (in pixels, per side) added around each 32x32
# image before it is resized back to 32x32.  Share of the padded width that
# is border, 2p / (32 + 2p):
#   2 -> 11%; 5 -> 24%; 9 -> 36%; 16 -> 50%; 27 -> 63%; 48 -> 75%
num_padding = 0
# Border width, in pixels of the resized image, that is overwritten later.
# NOTE(review): this divides by (32 + num_padding) although the padded side
# is 32 + 2 * num_padding, so it exceeds the resized border -- confirm that
# the overshoot is intended.
border_ratio = num_padding / (32 + num_padding)
num_padding_aux = torch.tensor(np.floor(border_ratio * 32), dtype=torch.int16)
# One image per batch while the raw data is being preprocessed.
batch_size = 1
transform = transforms.Compose([
    transforms.Pad(num_padding),
    Resize(size=(32, 32), interpolation=InterpolationMode.BICUBIC),
    transforms.ToTensor(),
])
trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform)
testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=False)
testloader = DataLoader(testset, batch_size=batch_size, shuffle=False)

class ConstructDataset(Dataset):
    """Map-style dataset over two parallel, indexable collections.

    Args:
        data: Indexable collection of samples; its length defines the
            dataset length.
        labels: Indexable collection of labels, indexed in step with
            ``data``.
        transform: Optional callable applied to each sample when it is
            fetched. ``None`` (the default) returns samples unchanged.
    """

    def __init__(self, data, labels, transform=None):
        self.data = data
        self.labels = labels
        self.transform = transform

    def __len__(self):
        """Return the number of samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(sample, label)`` for ``idx``, with ``transform`` applied."""
        sample = self.data[idx]
        label = self.labels[idx]

        # The transform was previously stored but silently ignored.
        if self.transform is not None:
            sample = self.transform(sample)

        return sample, label

def _fill_border_and_collect(loader, border):
    """Flatten the outer frame of every image and collect samples and labels.

    For each image yielded by ``loader`` the outermost ``border`` rows and
    columns are overwritten with the colour of the pixel at
    ``(border, border)``.  Images are modified in place inside the batch
    tensor.

    Args:
        loader: Iterable of ``(data, target)`` batches; ``data`` is indexed
            as ``(batch, channel, row, column)``.
        border: Frame width in pixels; a plain int or a 0-dim integer tensor.

    Returns:
        ``(samples, labels)``: two lists with one entry per image.  Each
        label keeps a leading dimension of size 1, so default batching
        yields labels of shape ``(batch, 1)``.
    """
    border = int(border)  # accept a 0-dim tensor as well as a plain int
    samples, labels = [], []
    for data, target in loader:
        height, width = data.shape[-2:]
        for i in range(target.shape[0]):
            if border > 0:
                # Copy the reference colour first so the assignments below
                # cannot alias it; shape (C, 1, 1) broadcasts over the frame.
                fill = data[i, :, border, border].clone().view(-1, 1, 1)
                data[i, :, :, :border] = fill
                data[i, :, :, width - border:] = fill
                data[i, :, :border, :] = fill
                data[i, :, height - border:, :] = fill
            # Index per sample (not squeeze(0) / whole target) so the result
            # is also correct for loaders with batch_size > 1.
            samples.append(data[i])
            labels.append(target[i:i + 1])
    return samples, labels


train_dataset_padding, train_label_padding = _fill_border_and_collect(
    trainloader, num_padding_aux)

# Visual sanity check on one processed training image (channels last).
plt.imshow(train_dataset_padding[18].permute(1,2,0))
plt.axis('off')
plt.show()

test_dataset_padding, test_label_padding = _fill_border_and_collect(
    testloader, num_padding_aux)

## dataset construction
# Serve the preprocessed samples in mini-batches of 256, in stored order.
batch_size = 256
train_dataset = ConstructDataset(data=train_dataset_padding, labels=train_label_padding)
test_dataset = ConstructDataset(data=test_dataset_padding, labels=test_label_padding)
trainloader_padding = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=False,
)
testloader_padding = DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
)

## model training
# Use the GPU when one is available; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = CIFAR10_CNN().to(device)
num_iter = 50                    # number of passes over the training loader
lr = 2
sigma = 0.98                     # passed to Opacus as noise_multiplier
delta = 1e-5                     # delta used when querying epsilon
max_per_sample_grad_norm = 0.1   # passed to Opacus as max_grad_norm
original_optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0)
privacy_engine = PrivacyEngine(secure_mode=False)
# make_private returns replacements for the model, optimizer and loader;
# only these returned objects are used from here on.
model, optimizer, trainloader_padding = privacy_engine.make_private(
    module=model,
    optimizer=original_optimizer,
    data_loader=trainloader_padding,
    noise_multiplier=sigma,
    max_grad_norm=max_per_sample_grad_norm,
    )
criterion = nn.CrossEntropyLoss()
for epoch in range(num_iter):
    # ---- training pass ----
    model.train()
    running_loss = 0
    total = 0
    for data, target in trainloader_padding:
        data = data.to(device)
        # Labels arrive as (batch, 1); the loss expects (batch,).
        target = target.to(device).squeeze(1)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        n_samples = target.shape[0]
        # criterion returns the batch mean, so weight it by the batch size.
        running_loss += loss.item() * n_samples
        total += n_samples
    print(f'Iteration {epoch + 1} loss: {running_loss / total}')
    print(len(trainloader_padding))

    # ---- evaluation pass ----
    correct = 0
    total = 0
    test_loss = 0
    model.eval()
    # No gradients are needed here; disabling autograd and accumulating
    # plain floats avoids keeping every batch's graph alive via test_loss.
    with torch.no_grad():
        for images, labels in testloader_padding:
            images = images.to(device)
            labels = labels.to(device).squeeze(1)
            outputs = model(images)
            n_samples = labels.shape[0]
            test_loss += criterion(outputs, labels).item() * n_samples
            _, predicted = torch.max(outputs, 1)
            total += n_samples
            correct += (predicted == labels).sum().item()
    test_acc = correct / total
    test_loss = test_loss / total
    print('test_loss:',test_loss)
    print('test_acc:', test_acc)
    epsilon = privacy_engine.get_epsilon(delta)
    print('privacy budget: ', epsilon)